{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transformer源代码解释之PyTorch篇\n",
    "在阅读完[2.2-图解transformer](./篇章2-Transformer相关原理/2.2-图解transformer.md)之后，希望大家能对transformer各个模块的设计和计算有一个形象的认识，本小节我们基于pytorch来实现一个Transformer，帮助大家进一步学习这个复杂的模型。\n",
    "**章节**\n",
    "\n",
    "- [词嵌入](#embed)\n",
    "- [位置编码](#pos)\n",
    "- [多头注意力](#multihead)\n",
    "- [搭建Transformer](#build)\n",
    "\n",
    "![](./pictures/0-1-transformer-arc.png)\n",
    "\n",
    "图：Transformer结构图"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **<div id='embed'>词嵌入</div>**\n",
    "\n",
    "如上图所示，Transformer图里左边的是Encoder，右边是Decoder部分。Encoder输入源语言序列，Decoder里面输入需要被翻译的语言文本（在训练时）。一个文本常有许多序列组成，常见操作为将序列进行一些预处理（如词切分等）变成列表，一个序列的列表的元素通常为词表中不可切分的最小词，整个文本就是一个大列表，元素为一个一个由序列组成的列表。如一个序列经过切分后变为[\"am\", \"##ro\", \"##zi\", \"meets\", \"his\", \"father\"]，接下来按照它们在词表中对应的索引进行转换，假设结果如[23, 94, 13, 41, 27, 96]。假如整个文本一共100个句子，那么就有100个列表为它的元素，因为每个序列的长度不一，需要设定最大长度，这里不妨设为128，那么将整个文本转换为数组之后，形状即为100 x 128，这就对应着batch_size和seq_length。\n",
    "\n",
    "输入之后，紧接着进行词嵌入处理，词嵌入就是将每一个词用预先训练好的向量进行映射。\n",
    "\n",
    "词嵌入在torch里基于`torch.nn.Embedding`实现，实例化时需要设置的参数为词表的大小和被映射的向量的维度比如`embed = nn.Embedding(10,8)`。向量的维度通俗来说就是向量里面有多少个数。注意，第一个参数是词表的大小，如果你目前最多有8个词，通常填写10（多一个位置留给unk和pad），你后面万一进入与这8个词不同的词就映射到unk上，序列padding的部分就映射到pad上。\n",
    "\n",
    "假如我们打算映射到8维（num_features或者embed_dim），那么，整个文本的形状变为100 x 128 x 8。接下来举个小例子解释一下：假设我们词表一共有10个词(算上unk和pad)，文本里有2个句子，每个句子有4个词，我们想要把每个词映射到8维的向量。于是2，4，8对应于batch_size, seq_length, embed_dim（如果batch在第一维的话）。\n",
    "\n",
    "另外，一般深度学习任务只改变num_features，所以讲维度一般是针对最后特征所在的维度。\n",
    "\n",
    "开始编程：\n",
    "\n",
    "所有需要的包的导入："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn.parameter import Parameter\n",
    "from torch.nn.init import xavier_uniform_\n",
    "from torch.nn.init import constant_\n",
    "from torch.nn.init import xavier_normal_\n",
    "import torch.nn.functional as F\n",
    "from typing import Optional, Tuple, Any\n",
    "from typing import List, Optional, Tuple\n",
    "import math\n",
    "import warnings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 4, 8])\n"
     ]
    }
   ],
   "source": [
    "X = torch.zeros((2,4),dtype=torch.long)\n",
    "embed = nn.Embedding(10,8)\n",
    "print(embed(X).shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **<div id='pos'>位置编码</div>**\n",
    "\n",
    "词嵌入之后紧接着就是位置编码，位置编码用以区分不同词以及同词不同特征之间的关系。代码中需要注意：X_只是初始化的矩阵，并不是输入进来的；完成位置编码之后会加一个dropout。另外，位置编码是最后加上去的，因此输入输出形状不变。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "Tensor = torch.Tensor\n",
    "def positional_encoding(X, num_features, dropout_p=0.1, max_len=512) -> Tensor:\n",
    "    r'''\n",
    "        给输入加入位置编码\n",
    "    参数：\n",
    "        - num_features: 输入进来的维度\n",
    "        - dropout_p: dropout的概率，当其为非零时执行dropout\n",
    "        - max_len: 句子的最大长度，默认512\n",
    "    \n",
    "    形状：\n",
    "        - 输入： [batch_size, seq_length, num_features]\n",
    "        - 输出： [batch_size, seq_length, num_features]\n",
    "\n",
    "    例子：\n",
    "        >>> X = torch.randn((2,4,10))\n",
    "        >>> X = positional_encoding(X, 10)\n",
    "        >>> print(X.shape)\n",
    "        >>> torch.Size([2, 4, 10])\n",
    "    '''\n",
    "\n",
    "    dropout = nn.Dropout(dropout_p)\n",
    "    P = torch.zeros((1,max_len,num_features))\n",
    "    X_ = torch.arange(max_len,dtype=torch.float32).reshape(-1,1) / torch.pow(\n",
    "        10000,\n",
    "        torch.arange(0,num_features,2,dtype=torch.float32) /num_features)\n",
    "    P[:,:,0::2] = torch.sin(X_)\n",
    "    P[:,:,1::2] = torch.cos(X_)\n",
    "    X = X + P[:,:X.shape[1],:].to(X.device)\n",
    "    return dropout(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 4, 10])\n"
     ]
    }
   ],
   "source": [
    "# 位置编码例子\n",
    "X = torch.randn((2,4,10))\n",
    "X = positional_encoding(X, 10)\n",
    "print(X.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **<div id='multihead'>多头注意力</div>**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 拆开看多头注意力机制\n",
    "**完整版本可运行的多头注意里机制的class在后面，先看一下完整的： 多头注意力机制-MultiheadAttention 小节再回来依次看下面的解释。**\n",
    "\n",
    "多头注意力类主要成分是：参数初始化、multi_head_attention_forward\n",
    "\n",
    "#### 初始化参数\n",
    "```python\n",
    "if self._qkv_same_embed_dim is False:\n",
    "    # 初始化前后形状维持不变\n",
    "    # (seq_length x embed_dim) x (embed_dim x embed_dim) ==> (seq_length x embed_dim)\n",
    "    self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim)))\n",
    "    self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim)))\n",
    "    self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim)))\n",
    "    self.register_parameter('in_proj_weight', None)\n",
    "else:\n",
    "    self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim)))\n",
    "    self.register_parameter('q_proj_weight', None)\n",
    "    self.register_parameter('k_proj_weight', None)\n",
    "    self.register_parameter('v_proj_weight', None)\n",
    "\n",
    "if bias:\n",
    "    self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))\n",
    "else:\n",
    "    self.register_parameter('in_proj_bias', None)\n",
    "# 后期会将所有头的注意力拼接在一起然后乘上权重矩阵输出\n",
    "# out_proj是为了后期准备的\n",
    "self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)\n",
    "self._reset_parameters()\n",
    "```\n",
    "\n",
    "torch.empty是按照所给的形状形成对应的tensor，特点是填充的值还未初始化，类比torch.randn（标准正态分布），这就是一种初始化的方式。在PyTorch中，变量类型是tensor的话是无法修改值的，而Parameter()函数可以看作为一种类型转变函数，将不可改值的tensor转换为可训练可修改的模型参数，即与model.parameters绑定在一起，register_parameter的意思是是否将这个参数放到model.parameters，None的意思是没有这个参数。\n",
    "\n",
    "这里有个if判断，用以判断q,k,v的最后一维是否一致，若一致，则一个大的权重矩阵全部乘然后分割出来，若不是，则各初始化各的，其实初始化是不会改变原来的形状的（如![](http://latex.codecogs.com/svg.latex?q=qW_q+b_q)，见注释）。\n",
    "\n",
    "可以发现最后有一个_reset_parameters()函数，这个是用来初始化参数数值的。xavier_uniform意思是从[连续型均匀分布](https://zh.wikipedia.org/wiki/%E9%80%A3%E7%BA%8C%E5%9E%8B%E5%9D%87%E5%8B%BB%E5%88%86%E5%B8%83)里面随机取样出值来作为初始化的值，xavier_normal_取样的分布是正态分布。正因为初始化值在训练神经网络的时候很重要，所以才需要这两个函数。\n",
    "\n",
    "constant_意思是用所给值来填充输入的向量。\n",
    "\n",
    "另外，在PyTorch的源码里，似乎projection代表是一种线性变换的意思，in_proj_bias的意思就是一开始的线性变换的偏置\n",
    "\n",
    "```python\n",
    "def _reset_parameters(self):\n",
    "    if self._qkv_same_embed_dim:\n",
    "        xavier_uniform_(self.in_proj_weight)\n",
    "    else:\n",
    "        xavier_uniform_(self.q_proj_weight)\n",
    "        xavier_uniform_(self.k_proj_weight)\n",
    "        xavier_uniform_(self.v_proj_weight)\n",
    "    if self.in_proj_bias is not None:\n",
    "        constant_(self.in_proj_bias, 0.)\n",
    "        constant_(self.out_proj.bias, 0.)\n",
    "\n",
    "```\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### multi_head_attention_forward\n",
    "这个函数如下代码所示，主要分成3个部分：\n",
    "- query, key, value通过_in_projection_packed变换得到q,k,v\n",
    "- 遮挡机制\n",
    "- 点积注意力"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "Tensor = torch.Tensor\n",
    "def multi_head_attention_forward(\n",
    "    query: Tensor,\n",
    "    key: Tensor,\n",
    "    value: Tensor,\n",
    "    num_heads: int,\n",
    "    in_proj_weight: Tensor,\n",
    "    in_proj_bias: Optional[Tensor],\n",
    "    dropout_p: float,\n",
    "    out_proj_weight: Tensor,\n",
    "    out_proj_bias: Optional[Tensor],\n",
    "    training: bool = True,\n",
    "    key_padding_mask: Optional[Tensor] = None,\n",
    "    need_weights: bool = True,\n",
    "    attn_mask: Optional[Tensor] = None,\n",
    "    use_seperate_proj_weight = None,\n",
    "    q_proj_weight: Optional[Tensor] = None,\n",
    "    k_proj_weight: Optional[Tensor] = None,\n",
    "    v_proj_weight: Optional[Tensor] = None,\n",
    ") -> Tuple[Tensor, Optional[Tensor]]:\n",
    "    r'''\n",
    "    形状：\n",
    "        输入：\n",
    "        - query：`(L, N, E)`\n",
    "        - key: `(S, N, E)`\n",
    "        - value: `(S, N, E)`\n",
    "        - key_padding_mask: `(N, S)`\n",
    "        - attn_mask: `(L, S)` or `(N * num_heads, L, S)`\n",
    "        输出：\n",
    "        - attn_output:`(L, N, E)`\n",
    "        - attn_output_weights:`(N, L, S)`\n",
    "    '''\n",
    "    tgt_len, bsz, embed_dim = query.shape\n",
    "    src_len, _, _ = key.shape\n",
    "    head_dim = embed_dim // num_heads\n",
    "    q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)\n",
    "\n",
    "    if attn_mask is not None:\n",
    "        if attn_mask.dtype == torch.uint8:\n",
    "            warnings.warn(\"Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.\")\n",
    "            attn_mask = attn_mask.to(torch.bool)\n",
    "        else:\n",
    "            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \\\n",
    "                f\"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}\"\n",
    "\n",
    "        if attn_mask.dim() == 2:\n",
    "            correct_2d_size = (tgt_len, src_len)\n",
    "            if attn_mask.shape != correct_2d_size:\n",
    "                raise RuntimeError(f\"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.\")\n",
    "            attn_mask = attn_mask.unsqueeze(0)\n",
    "        elif attn_mask.dim() == 3:\n",
    "            correct_3d_size = (bsz * num_heads, tgt_len, src_len)\n",
    "            if attn_mask.shape != correct_3d_size:\n",
    "                raise RuntimeError(f\"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.\")\n",
    "        else:\n",
    "            raise RuntimeError(f\"attn_mask's dimension {attn_mask.dim()} is not supported\")\n",
    "\n",
    "    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:\n",
    "        warnings.warn(\"Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.\")\n",
    "        key_padding_mask = key_padding_mask.to(torch.bool)\n",
    "    \n",
    "    # reshape q,k,v将Batch放在第一维以适合点积注意力\n",
    "    # 同时为多头机制，将不同的头拼在一起组成一层\n",
    "    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)\n",
    "    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)\n",
    "    v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)\n",
    "    if key_padding_mask is not None:\n",
    "        assert key_padding_mask.shape == (bsz, src_len), \\\n",
    "            f\"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}\"\n",
    "        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).   \\\n",
    "            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)\n",
    "        if attn_mask is None:\n",
    "            attn_mask = key_padding_mask\n",
    "        elif attn_mask.dtype == torch.bool:\n",
    "            attn_mask = attn_mask.logical_or(key_padding_mask)\n",
    "        else:\n",
    "            attn_mask = attn_mask.masked_fill(key_padding_mask, float(\"-inf\"))\n",
    "    # 若attn_mask值是布尔值，则将mask转换为float\n",
    "    if attn_mask is not None and attn_mask.dtype == torch.bool:\n",
    "        new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)\n",
    "        new_attn_mask.masked_fill_(attn_mask, float(\"-inf\"))\n",
    "        attn_mask = new_attn_mask\n",
    "\n",
    "    # 若training为True时才应用dropout\n",
    "    if not training:\n",
    "        dropout_p = 0.0\n",
    "    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)\n",
    "    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)\n",
    "    attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)\n",
    "    if need_weights:\n",
    "        # average attention weights over heads\n",
    "        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)\n",
    "        return attn_output, attn_output_weights.sum(dim=1) / num_heads\n",
    "    else:\n",
    "        return attn_output, None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### query, key, value通过_in_projection_packed变换得到q,k,v\n",
    "```\n",
    "q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)\n",
    "```\n",
    "\n",
    "对于`nn.functional.linear`函数，其实就是一个线性变换，与`nn.Linear`不同的是，前者可以提供权重矩阵和偏置，执行![](http://latex.codecogs.com/svg.latex?y=xW^T+b)，而后者是可以自由决定输出的维度。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _in_projection_packed(\n",
    "    q: Tensor,\n",
    "    k: Tensor,\n",
    "    v: Tensor,\n",
    "    w: Tensor,\n",
    "    b: Optional[Tensor] = None,\n",
    ") -> List[Tensor]:\n",
    "    r\"\"\"\n",
    "    用一个大的权重参数矩阵进行线性变换\n",
    "\n",
    "    参数:\n",
    "        q, k, v: 对自注意来说，三者都是src；对于seq2seq模型，k和v是一致的tensor。\n",
    "                 但它们的最后一维(num_features或者叫做embed_dim)都必须保持一致。\n",
    "        w: 用以线性变换的大矩阵，按照q,k,v的顺序压在一个tensor里面。\n",
    "        b: 用以线性变换的偏置，按照q,k,v的顺序压在一个tensor里面。\n",
    "\n",
    "    形状:\n",
    "        输入:\n",
    "        - q: shape:`(..., E)`，E是词嵌入的维度（下面出现的E均为此意）。\n",
    "        - k: shape:`(..., E)`\n",
    "        - v: shape:`(..., E)`\n",
    "        - w: shape:`(E * 3, E)`\n",
    "        - b: shape:`E * 3` \n",
    "\n",
    "        输出:\n",
    "        - 输出列表 :`[q', k', v']`，q,k,v经过线性变换前后的形状都一致。\n",
    "    \"\"\"\n",
    "    E = q.size(-1)\n",
    "    # 若为自注意，则q = k = v = src，因此它们的引用变量都是src\n",
    "    # 即k is v和q is k结果均为True\n",
    "    # 若为seq2seq，k = v，因而k is v的结果是True\n",
    "    if k is v:\n",
    "        if q is k:\n",
    "            return F.linear(q, w, b).chunk(3, dim=-1)\n",
    "        else:\n",
    "            # seq2seq模型\n",
    "            w_q, w_kv = w.split([E, E * 2])\n",
    "            if b is None:\n",
    "                b_q = b_kv = None\n",
    "            else:\n",
    "                b_q, b_kv = b.split([E, E * 2])\n",
    "            return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)\n",
    "    else:\n",
    "        w_q, w_k, w_v = w.chunk(3)\n",
    "        if b is None:\n",
    "            b_q = b_k = b_v = None\n",
    "        else:\n",
    "            b_q, b_k, b_v = b.chunk(3)\n",
    "        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)\n",
    "\n",
    "# q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***\n",
    "\n",
    "##### 遮挡机制\n",
    "\n",
    "对于attn_mask来说，若为2D，形状如`(L, S)`，L和S分别代表着目标语言和源语言序列长度，若为3D,形状如`(N * num_heads, L, S)`，N代表着batch_size，num_heads代表注意力头的数目。若为attn_mask的dtype为ByteTensor，非0的位置会被忽略不做注意力；若为BoolTensor，True对应的位置会被忽略；若为数值，则会直接加到attn_weights。\n",
    "\n",
    "因为在decoder解码的时候，只能看该位置和它之前的，如果看后面就犯规了，所以需要attn_mask遮挡住。\n",
    "\n",
    "下面函数直接复制PyTorch的，意思是确保不同维度的mask形状正确以及不同类型的转换\n",
    "\n",
    "\n",
    "```python\n",
    "if attn_mask is not None:\n",
    "    if attn_mask.dtype == torch.uint8:\n",
    "        warnings.warn(\"Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.\")\n",
    "        attn_mask = attn_mask.to(torch.bool)\n",
    "    else:\n",
    "        assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \\\n",
    "            f\"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}\"\n",
    "    # 对不同维度的形状判定\n",
    "    if attn_mask.dim() == 2:\n",
    "        correct_2d_size = (tgt_len, src_len)\n",
    "        if attn_mask.shape != correct_2d_size:\n",
    "            raise RuntimeError(f\"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.\")\n",
    "            attn_mask = attn_mask.unsqueeze(0)\n",
    "    elif attn_mask.dim() == 3:\n",
    "        correct_3d_size = (bsz * num_heads, tgt_len, src_len)\n",
    "        if attn_mask.shape != correct_3d_size:\n",
    "            raise RuntimeError(f\"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.\")\n",
    "    else:\n",
    "        raise RuntimeError(f\"attn_mask's dimension {attn_mask.dim()} is not supported\")\n",
    "\n",
    "```\n",
    "与`attn_mask`不同的是，`key_padding_mask`是用来遮挡住key里面的值，详细来说应该是`<PAD>`，被忽略的情况与attn_mask一致。\n",
    "\n",
    "```python\n",
    "# 将key_padding_mask值改为布尔值\n",
    "if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:\n",
    "    warnings.warn(\"Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.\")\n",
    "    key_padding_mask = key_padding_mask.to(torch.bool)\n",
    "```\n",
    "\n",
    "先介绍两个小函数，`logical_or`，输入两个tensor，并对这两个tensor里的值做`逻辑或`运算，只有当两个值均为0的时候才为`False`，其他时候均为`True`，另一个是`masked_fill`，输入是一个mask，和用以填充的值。mask由1，0组成，0的位置值维持不变，1的位置用新值填充。\n",
    "```python\n",
    "a = torch.tensor([0,1,10,0],dtype=torch.int8)\n",
    "b = torch.tensor([4,0,1,0],dtype=torch.int8)\n",
    "print(torch.logical_or(a,b))\n",
    "# tensor([ True,  True,  True, False])\n",
    "```\n",
    "\n",
    "```python\n",
    "r = torch.tensor([[0,0,0,0],[0,0,0,0]])\n",
    "mask = torch.tensor([[1,1,1,1],[0,0,0,0]])\n",
    "print(r.masked_fill(mask,1))\n",
    "# tensor([[1, 1, 1, 1],\n",
    "#         [0, 0, 0, 0]])\n",
    "```\n",
    "其实attn_mask和key_padding_mask有些时候对象是一致的，所以有时候可以合起来看。`-inf`做softmax之后值为0，即被忽略。\n",
    "```python\n",
    "if key_padding_mask is not None:\n",
    "    assert key_padding_mask.shape == (bsz, src_len), \\\n",
    "        f\"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}\"\n",
    "    key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).   \\\n",
    "        expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)\n",
    "    # 若attn_mask为空，直接用key_padding_mask\n",
    "    if attn_mask is None:\n",
    "        attn_mask = key_padding_mask\n",
    "    elif attn_mask.dtype == torch.bool:\n",
    "        attn_mask = attn_mask.logical_or(key_padding_mask)\n",
    "    else:\n",
    "        attn_mask = attn_mask.masked_fill(key_padding_mask, float(\"-inf\"))\n",
    "\n",
    "# 若attn_mask值是布尔值，则将mask转换为float\n",
    "if attn_mask is not None and attn_mask.dtype == torch.bool:\n",
    "    new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)\n",
    "    new_attn_mask.masked_fill_(attn_mask, float(\"-inf\"))\n",
    "    attn_mask = new_attn_mask\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***\n",
    "##### 点积注意力"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional, Tuple, Any\n",
    "def _scaled_dot_product_attention(\n",
    "    q: Tensor,\n",
    "    k: Tensor,\n",
    "    v: Tensor,\n",
    "    attn_mask: Optional[Tensor] = None,\n",
    "    dropout_p: float = 0.0,\n",
    ") -> Tuple[Tensor, Tensor]:\n",
    "    r'''\n",
    "    在query, key, value上计算点积注意力，若有注意力遮盖则使用，并且应用一个概率为dropout_p的dropout\n",
    "\n",
    "    参数：\n",
    "        - q: shape:`(B, Nt, E)` B代表batch size， Nt是目标语言序列长度，E是嵌入后的特征维度\n",
    "        - key: shape:`(B, Ns, E)` Ns是源语言序列长度\n",
    "        - value: shape:`(B, Ns, E)`与key形状一样\n",
    "        - attn_mask: 要么是3D的tensor，形状为:`(B, Nt, Ns)`或者2D的tensor，形状如:`(Nt, Ns)`\n",
    "\n",
    "        - Output: attention values: shape:`(B, Nt, E)`，与q的形状一致;attention weights: shape:`(B, Nt, Ns)`\n",
    "    \n",
    "    例子：\n",
    "        >>> q = torch.randn((2,3,6))\n",
    "        >>> k = torch.randn((2,4,6))\n",
    "        >>> v = torch.randn((2,4,6))\n",
    "        >>> out = scaled_dot_product_attention(q, k, v)\n",
    "        >>> out[0].shape, out[1].shape\n",
    "        >>> torch.Size([2, 3, 6]) torch.Size([2, 3, 4])\n",
    "    '''\n",
    "    B, Nt, E = q.shape\n",
    "    q = q / math.sqrt(E)\n",
    "    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)\n",
    "    attn = torch.bmm(q, k.transpose(-2,-1))\n",
    "    if attn_mask is not None:\n",
    "        attn += attn_mask \n",
    "    # attn意味着目标序列的每个词对源语言序列做注意力\n",
    "    attn = F.softmax(attn, dim=-1)\n",
    "    if dropout_p:\n",
    "        attn = F.dropout(attn, p=dropout_p)\n",
    "    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)\n",
    "    output = torch.bmm(attn, v)\n",
    "    return output, attn \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 完整的多头注意力机制-MultiheadAttention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiheadAttention(nn.Module):\n",
    "    r'''\n",
    "    参数：\n",
    "        embed_dim: 词嵌入的维度\n",
    "        num_heads: 平行头的数量\n",
    "        batch_first: 若`True`，则为(batch, seq, feture)，若为`False`，则为(seq, batch, feature)\n",
    "    \n",
    "    例子：\n",
    "        >>> multihead_attn = MultiheadAttention(embed_dim, num_heads)\n",
    "        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)\n",
    "    '''\n",
    "    def __init__(self, embed_dim, num_heads, dropout=0., bias=True,\n",
    "                 kdim=None, vdim=None, batch_first=False) -> None:\n",
    "        # factory_kwargs = {'device': device, 'dtype': dtype}\n",
    "        super(MultiheadAttention, self).__init__()\n",
    "        self.embed_dim = embed_dim\n",
    "        self.kdim = kdim if kdim is not None else embed_dim\n",
    "        self.vdim = vdim if vdim is not None else embed_dim\n",
    "        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim\n",
    "\n",
    "        self.num_heads = num_heads\n",
    "        self.dropout = dropout\n",
    "        self.batch_first = batch_first\n",
    "        self.head_dim = embed_dim // num_heads\n",
    "        assert self.head_dim * num_heads == self.embed_dim, \"embed_dim must be divisible by num_heads\"\n",
    "\n",
    "        if self._qkv_same_embed_dim is False:\n",
    "            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim)))\n",
    "            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim)))\n",
    "            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim)))\n",
    "            self.register_parameter('in_proj_weight', None)\n",
    "        else:\n",
    "            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim)))\n",
    "            self.register_parameter('q_proj_weight', None)\n",
    "            self.register_parameter('k_proj_weight', None)\n",
    "            self.register_parameter('v_proj_weight', None)\n",
    "\n",
    "        if bias:\n",
    "            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))\n",
    "        else:\n",
    "            self.register_parameter('in_proj_bias', None)\n",
    "        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)\n",
    "\n",
    "        self._reset_parameters()\n",
    "\n",
    "    def _reset_parameters(self):\n",
    "        if self._qkv_same_embed_dim:\n",
    "            xavier_uniform_(self.in_proj_weight)\n",
    "        else:\n",
    "            xavier_uniform_(self.q_proj_weight)\n",
    "            xavier_uniform_(self.k_proj_weight)\n",
    "            xavier_uniform_(self.v_proj_weight)\n",
    "\n",
    "        if self.in_proj_bias is not None:\n",
    "            constant_(self.in_proj_bias, 0.)\n",
    "            constant_(self.out_proj.bias, 0.)\n",
    "\n",
    "\n",
    "\n",
    "    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,\n",
    "                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:\n",
    "        if self.batch_first:\n",
    "            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]\n",
    "\n",
    "        if not self._qkv_same_embed_dim:\n",
    "            attn_output, attn_output_weights = multi_head_attention_forward(\n",
    "                query, key, value, self.num_heads,\n",
    "                self.in_proj_weight, self.in_proj_bias,\n",
    "                self.dropout, self.out_proj.weight, self.out_proj.bias,\n",
    "                training=self.training,\n",
    "                key_padding_mask=key_padding_mask, need_weights=need_weights,\n",
    "                attn_mask=attn_mask, use_separate_proj_weight=True,\n",
    "                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,\n",
    "                v_proj_weight=self.v_proj_weight)\n",
    "        else:\n",
    "            attn_output, attn_output_weights = multi_head_attention_forward(\n",
    "                query, key, value, self.num_heads,\n",
    "                self.in_proj_weight, self.in_proj_bias,\n",
    "                self.dropout, self.out_proj.weight, self.out_proj.bias,\n",
    "                training=self.training,\n",
    "                key_padding_mask=key_padding_mask, need_weights=need_weights,\n",
    "                attn_mask=attn_mask)\n",
    "        if self.batch_first:\n",
    "            return attn_output.transpose(1, 0), attn_output_weights\n",
    "        else:\n",
    "            return attn_output, attn_output_weights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "接下来可以实践一下，并且把位置编码加起来，可以发现加入位置编码和进行多头注意力的前后形状都是不会变的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 4, 100])\n",
      "torch.Size([2, 4, 100]) torch.Size([4, 2, 2])\n"
     ]
    }
   ],
   "source": [
    "# 因为batch_first为False,所以src的shape：`(seq, batch, embed_dim)`\n",
    "src = torch.randn((2,4,100))\n",
    "src = positional_encoding(src,100,0.1)\n",
    "print(src.shape)\n",
    "multihead_attn = MultiheadAttention(100, 4, 0.1)\n",
    "attn_output, attn_output_weights = multihead_attn(src,src,src)\n",
    "print(attn_output.shape, attn_output_weights.shape)\n",
    "\n",
    "# torch.Size([2, 4, 100])\n",
    "# torch.Size([2, 4, 100]) torch.Size([4, 2, 2])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***\n",
    "## **<div id='build'>搭建Transformer</div>**\n",
    "- Encoder Layer\n",
    "\n",
    "![](./pictures/2-2-1-encoder.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerEncoderLayer(nn.Module):\n",
    "    r'''\n",
    "    参数：\n",
    "        d_model: 词嵌入的维度（必备）\n",
    "        nhead: 多头注意力中平行头的数目（必备）\n",
    "        dim_feedforward: 全连接层的神经元的数目，又称经过此层输入的维度（Default = 2048）\n",
    "        dropout: dropout的概率（Default = 0.1）\n",
    "        activation: 两个线性层中间的激活函数，默认relu或gelu\n",
    "        lay_norm_eps: layer normalization中的微小量，防止分母为0（Default = 1e-5）\n",
    "        batch_first: 若`True`，则为(batch, seq, feture)，若为`False`，则为(seq, batch, feature)（Default：False）\n",
    "\n",
    "    例子：\n",
    "        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)\n",
    "        >>> src = torch.randn((32, 10, 512))\n",
    "        >>> out = encoder_layer(src)\n",
    "    '''\n",
    "\n",
    "    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu,\n",
    "                 layer_norm_eps=1e-5, batch_first=False) -> None:\n",
    "        super(TransformerEncoderLayer, self).__init__()\n",
    "        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)\n",
    "        self.linear1 = nn.Linear(d_model, dim_feedforward)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.linear2 = nn.Linear(dim_feedforward, d_model)\n",
    "\n",
    "        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)\n",
    "        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)\n",
    "        self.dropout1 = nn.Dropout(dropout)\n",
    "        self.dropout2 = nn.Dropout(dropout)\n",
    "        self.activation = activation        \n",
    "\n",
    "\n",
    "    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:\n",
    "        src = positional_encoding(src, src.shape[-1])\n",
    "        src2 = self.self_attn(src, src, src, attn_mask=src_mask, \n",
    "        key_padding_mask=src_key_padding_mask)[0]\n",
    "        src = src + self.dropout1(src2)\n",
    "        src = self.norm1(src)\n",
    "        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))\n",
    "        src = src + self.dropout(src2)\n",
    "        src = self.norm2(src)\n",
    "        return src\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 10, 512])\n"
     ]
    }
   ],
   "source": [
    "# 用小例子看一下\n",
    "encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)\n",
    "src = torch.randn((32, 10, 512))\n",
    "out = encoder_layer(src)\n",
    "print(out.shape)\n",
    "# torch.Size([32, 10, 512])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transformer layer组成Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerEncoder(nn.Module):\n",
    "    r'''\n",
    "    参数：\n",
    "        encoder_layer（必备）\n",
    "        num_layers： encoder_layer的层数（必备）\n",
    "        norm: 归一化的选择（可选）\n",
    "    \n",
    "    例子：\n",
    "        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)\n",
    "        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)\n",
    "        >>> src = torch.randn((10, 32, 512))\n",
    "        >>> out = transformer_encoder(src)\n",
    "    '''\n",
    "\n",
    "    def __init__(self, encoder_layer, num_layers, norm=None):\n",
    "        super(TransformerEncoder, self).__init__()\n",
    "        self.layer = encoder_layer\n",
    "        self.num_layers = num_layers\n",
    "        self.norm = norm\n",
    "    \n",
    "    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:\n",
    "        output = positional_encoding(src, src.shape[-1])\n",
    "        for _ in range(self.num_layers):\n",
    "            output = self.layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)\n",
    "        \n",
    "        if self.norm is not None:\n",
    "            output = self.norm(output)\n",
    "        \n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10, 32, 512])\n"
     ]
    }
   ],
   "source": [
    "# 例子\n",
    "encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)\n",
    "transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)\n",
    "src = torch.randn((10, 32, 512))\n",
    "out = transformer_encoder(src)\n",
    "print(out.shape)\n",
    "# torch.Size([10, 32, 512])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***\n",
    "## Decoder Layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerDecoderLayer(nn.Module):\n",
    "    r'''\n",
    "    参数：\n",
    "        d_model: 词嵌入的维度（必备）\n",
    "        nhead: 多头注意力中平行头的数目（必备）\n",
    "        dim_feedforward: 全连接层的神经元的数目，又称经过此层输入的维度（Default = 2048）\n",
    "        dropout: dropout的概率（Default = 0.1）\n",
    "        activation: 两个线性层中间的激活函数，默认relu或gelu\n",
    "        lay_norm_eps: layer normalization中的微小量，防止分母为0（Default = 1e-5）\n",
    "        batch_first: 若`True`，则为(batch, seq, feture)，若为`False`，则为(seq, batch, feature)（Default：False）\n",
    "    \n",
    "    例子：\n",
    "        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)\n",
    "        >>> memory = torch.randn((10, 32, 512))\n",
    "        >>> tgt = torch.randn((20, 32, 512))\n",
    "        >>> out = decoder_layer(tgt, memory)\n",
    "    '''\n",
    "    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu,\n",
    "                 layer_norm_eps=1e-5, batch_first=False) -> None:\n",
    "        super(TransformerDecoderLayer, self).__init__()\n",
    "        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)\n",
    "        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)\n",
    "\n",
    "        self.linear1 = nn.Linear(d_model, dim_feedforward)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.linear2 = nn.Linear(dim_feedforward, d_model)\n",
    "\n",
    "        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)\n",
    "        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)\n",
    "        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)\n",
    "        self.dropout1 = nn.Dropout(dropout)\n",
    "        self.dropout2 = nn.Dropout(dropout)\n",
    "        self.dropout3 = nn.Dropout(dropout)\n",
    "\n",
    "        self.activation = activation\n",
    "\n",
    "    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, \n",
    "                memory_mask: Optional[Tensor] = None,tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:\n",
    "        r'''\n",
    "        参数：\n",
    "            tgt: 目标语言序列（必备）\n",
    "            memory: 从最后一个encoder_layer跑出的句子（必备）\n",
    "            tgt_mask: 目标语言序列的mask（可选）\n",
    "            memory_mask（可选）\n",
    "            tgt_key_padding_mask（可选）\n",
    "            memory_key_padding_mask（可选）\n",
    "        '''\n",
    "        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,\n",
    "                              key_padding_mask=tgt_key_padding_mask)[0]\n",
    "        tgt = tgt + self.dropout1(tgt2)\n",
    "        tgt = self.norm1(tgt)\n",
    "        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,\n",
    "                                   key_padding_mask=memory_key_padding_mask)[0]\n",
    "        tgt = tgt + self.dropout2(tgt2)\n",
    "        tgt = self.norm2(tgt)\n",
    "        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))\n",
    "        tgt = tgt + self.dropout3(tgt2)\n",
    "        tgt = self.norm3(tgt)\n",
    "        return tgt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([20, 32, 512])\n"
     ]
    }
   ],
   "source": [
    "# 可爱的小例子\n",
    "decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)\n",
    "memory = torch.randn((10, 32, 512))\n",
    "tgt = torch.randn((20, 32, 512))\n",
    "out = decoder_layer(tgt, memory)\n",
    "print(out.shape)\n",
    "# torch.Size([20, 32, 512])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerDecoder(nn.Module):\n",
    "    r'''\n",
    "    参数：\n",
    "        decoder_layer（必备）\n",
    "        num_layers: decoder_layer的层数（必备）\n",
    "        norm: 归一化选择\n",
    "    \n",
    "    例子：\n",
    "        >>> decoder_layer =TransformerDecoderLayer(d_model=512, nhead=8)\n",
    "        >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)\n",
    "        >>> memory = torch.rand(10, 32, 512)\n",
    "        >>> tgt = torch.rand(20, 32, 512)\n",
    "        >>> out = transformer_decoder(tgt, memory)\n",
    "    '''\n",
    "    def __init__(self, decoder_layer, num_layers, norm=None):\n",
    "        super(TransformerDecoder, self).__init__()\n",
    "        self.layer = decoder_layer\n",
    "        self.num_layers = num_layers\n",
    "        self.norm = norm\n",
    "    \n",
    "    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,\n",
    "                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,\n",
    "                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:\n",
    "        output = tgt\n",
    "        for _ in range(self.num_layers):\n",
    "            output = self.layer(output, memory, tgt_mask=tgt_mask,\n",
    "                         memory_mask=memory_mask,\n",
    "                         tgt_key_padding_mask=tgt_key_padding_mask,\n",
    "                         memory_key_padding_mask=memory_key_padding_mask)\n",
    "        if self.norm is not None:\n",
    "            output = self.norm(output)\n",
    "\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([20, 32, 512])\n"
     ]
    }
   ],
   "source": [
    "# 可爱的小例子\n",
    "decoder_layer =TransformerDecoderLayer(d_model=512, nhead=8)\n",
    "transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)\n",
    "memory = torch.rand(10, 32, 512)\n",
    "tgt = torch.rand(20, 32, 512)\n",
    "out = transformer_decoder(tgt, memory)\n",
    "print(out.shape)\n",
    "# torch.Size([20, 32, 512])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "总结一下，其实经过位置编码，多头注意力，Encoder Layer和Decoder Layer形状不会变的，而Encoder和Decoder分别与src和tgt形状一致"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Transformer(nn.Module):\n",
    "    r'''\n",
    "    参数：\n",
    "        d_model: 词嵌入的维度（必备）（Default=512）\n",
    "        nhead: 多头注意力中平行头的数目（必备）（Default=8）\n",
    "        num_encoder_layers:编码层层数（Default=8）\n",
    "        num_decoder_layers:解码层层数（Default=8）\n",
    "        dim_feedforward: 全连接层的神经元的数目，又称经过此层输入的维度（Default = 2048）\n",
    "        dropout: dropout的概率（Default = 0.1）\n",
    "        activation: 两个线性层中间的激活函数，默认relu或gelu\n",
    "        custom_encoder: 自定义encoder（Default=None）\n",
    "        custom_decoder: 自定义decoder（Default=None）\n",
    "        lay_norm_eps: layer normalization中的微小量，防止分母为0（Default = 1e-5）\n",
    "        batch_first: 若`True`，则为(batch, seq, feture)，若为`False`，则为(seq, batch, feature)（Default：False）\n",
    "    \n",
    "    例子：\n",
    "        >>> transformer_model = Transformer(nhead=16, num_encoder_layers=12)\n",
    "        >>> src = torch.rand((10, 32, 512))\n",
    "        >>> tgt = torch.rand((20, 32, 512))\n",
    "        >>> out = transformer_model(src, tgt)\n",
    "    '''\n",
    "    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,\n",
    "                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,\n",
    "                 activation = F.relu, custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,\n",
    "                 layer_norm_eps: float = 1e-5, batch_first: bool = False) -> None:\n",
    "        super(Transformer, self).__init__()\n",
    "        if custom_encoder is not None:\n",
    "            self.encoder = custom_encoder\n",
    "        else:\n",
    "            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,\n",
    "                                                    activation, layer_norm_eps, batch_first)\n",
    "            encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)\n",
    "            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers)\n",
    "\n",
    "        if custom_decoder is not None:\n",
    "            self.decoder = custom_decoder\n",
    "        else:\n",
    "            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,\n",
    "                                                    activation, layer_norm_eps, batch_first)\n",
    "            decoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)\n",
    "            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)\n",
    "\n",
    "        self._reset_parameters()\n",
    "\n",
    "        self.d_model = d_model\n",
    "        self.nhead = nhead\n",
    "\n",
    "        self.batch_first = batch_first\n",
    "\n",
    "    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,\n",
    "                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,\n",
    "                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:\n",
    "        r'''\n",
    "        参数：\n",
    "            src: 源语言序列（送入Encoder）（必备）\n",
    "            tgt: 目标语言序列（送入Decoder）（必备）\n",
    "            src_mask: （可选)\n",
    "            tgt_mask: （可选）\n",
    "            memory_mask: （可选）\n",
    "            src_key_padding_mask: （可选）\n",
    "            tgt_key_padding_mask: （可选）\n",
    "            memory_key_padding_mask: （可选）\n",
    "        \n",
    "        形状：\n",
    "            - src: shape:`(S, N, E)`, `(N, S, E)` if batch_first.\n",
    "            - tgt: shape:`(T, N, E)`, `(N, T, E)` if batch_first.\n",
    "            - src_mask: shape:`(S, S)`.\n",
    "            - tgt_mask: shape:`(T, T)`.\n",
    "            - memory_mask: shape:`(T, S)`.\n",
    "            - src_key_padding_mask: shape:`(N, S)`.\n",
    "            - tgt_key_padding_mask: shape:`(N, T)`.\n",
    "            - memory_key_padding_mask: shape:`(N, S)`.\n",
    "\n",
    "            [src/tgt/memory]_mask确保有些位置不被看到，如做decode的时候，只能看该位置及其以前的，而不能看后面的。\n",
    "            若为ByteTensor，非0的位置会被忽略不做注意力；若为BoolTensor，True对应的位置会被忽略；\n",
    "            若为数值，则会直接加到attn_weights\n",
    "\n",
    "            [src/tgt/memory]_key_padding_mask 使得key里面的某些元素不参与attention计算，三种情况同上\n",
    "\n",
    "            - output: shape:`(T, N, E)`, `(N, T, E)` if batch_first.\n",
    "\n",
    "        注意：\n",
    "            src和tgt的最后一维需要等于d_model，batch的那一维需要相等\n",
    "            \n",
    "        例子:\n",
    "            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)\n",
    "        '''\n",
    "        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)\n",
    "        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,\n",
    "                              tgt_key_padding_mask=tgt_key_padding_mask,\n",
    "                              memory_key_padding_mask=memory_key_padding_mask)\n",
    "        return output\n",
    "        \n",
    "    def generate_square_subsequent_mask(self, sz: int) -> Tensor:\n",
    "        r'''产生关于序列的mask，被遮住的区域赋值`-inf`，未被遮住的区域赋值为`0`'''\n",
    "        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)\n",
    "        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))\n",
    "        return mask\n",
    "\n",
    "    def _reset_parameters(self):\n",
    "        r'''用正态分布初始化参数'''\n",
    "        for p in self.parameters():\n",
    "            if p.dim() > 1:\n",
    "                xavier_uniform_(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([20, 32, 512])\n"
     ]
    }
   ],
   "source": [
    "# 小例子\n",
    "transformer_model = Transformer(nhead=16, num_encoder_layers=12)\n",
    "src = torch.rand((10, 32, 512))\n",
    "tgt = torch.rand((20, 32, 512))\n",
    "out = transformer_model(src, tgt)\n",
    "print(out.shape)\n",
    "# torch.Size([20, 32, 512])"
   ]
  },
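  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `forward` docstring above says that a float mask is added directly to the attention weights. The cell below is a minimal sketch of that idea, assuming the additive convention (`-inf` for hidden positions, `0` for visible ones). `causal_mask` and `scores` are hypothetical names introduced only for this illustration and are not part of the classes above; the all-zero `scores` tensor is a stand-in for real attention scores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def causal_mask(sz: int) -> torch.Tensor:\n",
    "    # Hypothetical helper: additive mask of shape (sz, sz).\n",
    "    # Entries above the diagonal are -inf (later positions are hidden); all others are 0.\n",
    "    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)\n",
    "\n",
    "mask = causal_mask(4)\n",
    "# Assumed stand-in for raw attention scores: all zeros, i.e. uniform before masking.\n",
    "scores = torch.zeros(4, 4)\n",
    "# Adding the mask before softmax drives the weight of hidden positions to zero.\n",
    "weights = torch.softmax(scores + mask, dim=-1)\n",
    "print(mask)\n",
    "print(weights)  # row i spreads its weight evenly over positions 0..i"
   ]
  },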
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "到此为止，PyTorch的Transformer库我们已经全部实现，相比于官方的版本，手写的这个少了较多的判定语句。\n",
    "## 致谢\n",
    "本文由台运鹏撰写，本项目成员重新组织和整理。最后，期待您的阅读反馈和star，谢谢。"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "3bfce0b4c492a35815b5705a19fe374a7eea0baaa08b34d90450caf1fe9ce20b"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit ('venv': virtualenv)",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": ""
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}